{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a40097eb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch, os, glob, gc\n",
    "os.environ['VXM_BACKEND'] = 'pytorch'\n",
    "os.environ['NEURITE_BACKEND'] = 'pytorch'\n",
    "from torch.autograd import Variable\n",
    "from scipy import ndimage\n",
    "import enum\n",
    "import torchvision\n",
    "import math, random\n",
    "import nibabel as nib\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset\n",
    "from voxelmorph import voxelmorph\n",
    "import numpy as np\n",
    "import time\n",
    "import lagomorph as lm\n",
    "from lagomorph import adjrep \n",
    "from lagomorph import deform \n",
    "import SimpleITK as sitk\n",
    "# from atlas_builder.atlas_trainer import main_atlas_train\n",
    "from SADIR_forward import get_diffused_image, get_deformed_image\n",
    "\n",
    "# Edge length of the cubic volumes: it sizes the zero-filled array built in\n",
    "# SADIRData.__getitem__ and the `temp` buffer allocated after the loader.\n",
    "IMAGE_SIZE=64\n",
    "# Default CUDA device handle; the model-loading cell rebinds `device` to cuda:0.\n",
    "device = torch.device('cuda')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4d0cc61",
   "metadata": {},
   "source": [
    "### Data Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01e3ea33",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load Data\n",
    "class SADIRData(Dataset):\n",
    "    \"\"\"Dataset pairing a sparsely sampled volume with a binarized atlas.\n",
    "\n",
    "    Each item is a 2-channel float tensor of IMAGE_SIZE^3 voxels per channel:\n",
    "    channel 0 holds a few axis-1 slices copied from the loaded volume (every\n",
    "    other voxel is zero), channel 1 holds the atlas thresholded at 0.5.\n",
    "\n",
    "    Args:\n",
    "        path: Root directory; volumes are read from ``path + '/test_64/'`` when\n",
    "            ``test_flag`` is truthy, otherwise from ``path + '/train_64/'``.\n",
    "        test_flag: Selects the split and the second element returned by\n",
    "            ``__getitem__`` (file name for test, full volume for train).\n",
    "        atlas_path: NIfTI atlas file; defaults to the previously hard-coded\n",
    "            ``'./final_atlas.nii'``.\n",
    "    \"\"\"\n",
    "\n",
    "    # Axis-1 slice indices copied into the sparse input\n",
    "    # (14, 19, 23, 28, 33, 38, 42, 47, 52, 57).\n",
    "    _SLICE_IDS = np.linspace(0, 62, 14).astype(np.int16)[3:-1]\n",
    "\n",
    "    def __init__(self, path, test_flag=False, atlas_path='./final_atlas.nii'):\n",
    "        self.test_flag = test_flag\n",
    "        self.atlas_path = atlas_path\n",
    "        self._atlas = None  # filled lazily by _load_atlas()\n",
    "        split_dir = \"/test_64/\" if test_flag else \"/train_64/\"\n",
    "        self.path = path + split_dir\n",
    "        entries = os.listdir(self.path)\n",
    "        # Hidden entries (e.g. .DS_Store) are skipped; sorting makes the\n",
    "        # index -> file mapping deterministic, which os.listdir does not promise.\n",
    "        self.filenames = sorted(name for name in entries if not name.startswith('.'))\n",
    "        print(\"number of files in directory:\", len(entries))\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of usable (non-hidden) volume files.\"\"\"\n",
    "        return len(self.filenames)\n",
    "\n",
    "    def _load_atlas(self):\n",
    "        \"\"\"Return the binarized atlas, reading it from disk on first use only.\"\"\"\n",
    "        if self._atlas is None:\n",
    "            atlas = np.array(nib.load(self.atlas_path).get_fdata())\n",
    "            atlas[atlas < 0.5] = 0\n",
    "            atlas[atlas >= 0.5] = 1\n",
    "            self._atlas = atlas\n",
    "        return self._atlas\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Build the (sparse volume, atlas) pair for file ``idx``.\n",
    "\n",
    "        Returns:\n",
    "            ``(mask_bf, filename)`` in test mode, ``(mask_bf, volume)`` in train\n",
    "            mode, where ``mask_bf`` is the 2-channel float tensor described on\n",
    "            the class and ``volume`` is the full loaded volume with a leading\n",
    "            channel axis.\n",
    "        \"\"\"\n",
    "        fname = self.filenames[idx]\n",
    "        obj = np.array(nib.load(os.path.join(self.path, fname)).get_fdata())\n",
    "        atlas = self._load_atlas()\n",
    "\n",
    "        # Keep only the selected axis-1 slices; the rest of the volume stays 0.\n",
    "        # NOTE(review): assumes every volume is IMAGE_SIZE^3 - confirm for new data.\n",
    "        img = np.zeros((IMAGE_SIZE, IMAGE_SIZE, IMAGE_SIZE))\n",
    "        img[:, self._SLICE_IDS, :] = obj[:, self._SLICE_IDS, :]\n",
    "\n",
    "        # The assignments below copy, so the cached atlas is never mutated.\n",
    "        mask_bf = np.zeros((2,) + img.shape)\n",
    "        mask_bf[0] = img\n",
    "        mask_bf[1] = atlas\n",
    "\n",
    "        if self.test_flag:\n",
    "            return torch.from_numpy(mask_bf).float(), fname\n",
    "        obj = obj[np.newaxis, :, :, :]\n",
    "        return torch.from_numpy(mask_bf).float(), torch.from_numpy(obj).float()\n",
    "        \n",
    "# Loader configuration: dataset root and mini-batch size.\n",
    "args={'data_dir' : './',\n",
    "      'batch_size' : 2}\n",
    "\n",
    "ds = SADIRData(args['data_dir'], test_flag=True)\n",
    "datal= torch.utils.data.DataLoader(\n",
    "    ds,\n",
    "    batch_size=args['batch_size'],\n",
    "    shuffle=True)\n",
    "data = iter(datal)\n",
    "# len(datal) is the batch count and needs no I/O; materialising the loader\n",
    "# with list(datal) would read every volume from disk just to count them.\n",
    "print(\"number of files: \", len(datal))\n",
    "# Zero-initialised scratch buffer on the current CUDA device (torch.zeros\n",
    "# replaces the deprecated torch.cuda.FloatTensor constructor).\n",
    "temp = torch.zeros(args['batch_size'], 3, IMAGE_SIZE, IMAGE_SIZE, IMAGE_SIZE, dtype=torch.float32, device='cuda')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6124863b",
   "metadata": {},
   "source": [
    "### Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d044e9ae",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the trained registration network onto the first GPU.\n",
    "device = torch.device(\"cuda:0\")\n",
    "model = voxelmorph.torch.networks.VxmDense.load('./trained_models/0510.pt', device)\n",
    "model.to(device)\n",
    "# The notebook only runs inference: switch train-time layers (dropout, batch\n",
    "# norm, if the network has any) to evaluation behaviour. eval() returns the\n",
    "# module, so the cell still displays the model summary.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e771ee2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "gt_path = './test_64/'\n",
    "pred_dir = './predictions_/'\n",
    "# nib.save does not create missing directories, so make sure the output\n",
    "# folder exists before spending time on inference.\n",
    "os.makedirs(pred_dir, exist_ok=True)\n",
    "\n",
    "# Iterate the loader directly instead of counting batches with\n",
    "# len(list(datal)), which reads every test volume once just to get a number.\n",
    "for prior_batch, fnames in datal:\n",
    "    # Reconstruct one sample at a time: x_t carries a batch dimension of 1, so\n",
    "    # concatenating it with a multi-sample batch fails, and every sample of the\n",
    "    # batch needs its own prediction file.\n",
    "    for sample, fname in zip(prior_batch, fnames):\n",
    "        prior = sample.unsqueeze(0).to(device)                 # [sparse image, atlas]\n",
    "        atlas = sample[1].unsqueeze(0).unsqueeze(0).to(device)\n",
    "        time_ = random.randint(10,999)                         # number of refinement steps\n",
    "        # Start from Gaussian noise shaped like the image channel.\n",
    "        x_t = torch.randn_like(sample[0].unsqueeze(0).unsqueeze(0)).to(device)\n",
    "        for t_ in range(time_):\n",
    "            # prior and x_t must live on the same device before concatenation.\n",
    "            inputs = torch.cat([prior, x_t], dim=1)\n",
    "            # Inference only: without no_grad each step's graph stays reachable\n",
    "            # through x_t, so memory grows with the number of steps.\n",
    "            with torch.no_grad():\n",
    "                m0_pred = model(inputs, torch.tensor([t_], device=device))\n",
    "            # Deform the atlas with the model output to get the current estimate.\n",
    "            x0_pred = get_deformed_image(m0_pred, atlas).squeeze()\n",
    "            x_t = x0_pred.unsqueeze(0).unsqueeze(0).to(device)\n",
    "\n",
    "        # The ground-truth file supplies the affine/header for the saved volume.\n",
    "        yim = nib.load(os.path.join(gt_path, fname))\n",
    "        x0_pred_prc = x0_pred.detach().cpu().numpy().squeeze()\n",
    "        # Binarise at the midpoint of the value range. The mask is computed once:\n",
    "        # writing 1s and then zeroing values < k would wipe those 1s when k > 1.\n",
    "        k = (np.amax(x0_pred_prc) + np.amin(x0_pred_prc))/2\n",
    "        x0_pred_bin = (x0_pred_prc >= k).astype(x0_pred_prc.dtype)\n",
    "        nib.save(nib.Nifti1Image(x0_pred_bin, yim.affine, yim.header), os.path.join(pred_dir, str(fname)))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8",
   "language": "python",
   "name": "python-3.8"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
